{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CjgHx8VbcL26"
      },
      "outputs": [],
      "source": [
        "import nlpaug.augmenter.char as nac\n",
        "\n",
        "import nlpaug.augmenter.word as naw\n",
        "\n",
        "import nlpaug.augmenter.sentence as nas\n",
        "\n",
        "import nlpaug.flow as nafc\n",
        "\n",
        "from nlpaug.util import Action"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "imdb_train= load_dataset('imdb',\n",
        "\n",
        "split=\"train[:1000]+train[-1000:]\")\n",
        "\n",
        "imdb_test= load_dataset('imdb',\n",
        "\n",
        "split=\"test[:500]+test[-500:]\")\n",
        "\n",
        "imdb_val= load_dataset('imdb',\n",
        "\n",
        "split=\"test[500:1000]+test[-1000:-500]\")\n",
        "\n",
        "imdb_train.shape, imdb_test.shape, imdb_val.shape"
      ],
      "metadata": {
        "id": "oqjwHUZgcS8X"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from sklearn.metrics import (accuracy_score,\n",
        "\n",
        "precision_recall_fscore_support)\n",
        "\n",
        "def compute_metrics(pred):\n",
        "  labels = pred.label_ids\n",
        "  preds = pred.predictions.argmax(-1)\n",
        "  precision, recall, f1, _ = precision_recall_fscore_support(labels,preds, average='macro')\n",
        "  acc = accuracy_score(labels, preds)\n",
        "  return {\n",
        "  'Accuracy': acc,\n",
        "  'F1': f1,\n",
        "  'Precision': precision,\n",
        "  'Recall': recall\n",
        "  }\n",
        "\n",
        "def tokenize_it(e):\n",
        "\n",
        "  return tokenizer(e['text'],\n",
        "  padding=True,\n",
        "  truncation=True)"
      ],
      "metadata": {
        "id": "NUM7RgOLcWhB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import nlpaug.augmenter.word as naw\n",
        "\n",
        "#substitute character by keyboard distance\n",
        "\n",
        "aug1 = nac.KeyboardAug(aug_word_p=0.2,\n",
        "\n",
        "aug_char_max=2,\n",
        "\n",
        "aug_word_max=4)\n",
        "\n",
        "# random insert/swap/delete\n",
        "\n",
        "aug2 = nac.RandomCharAug(action=\"insert\", aug_char_max=1)\n",
        "\n",
        "aug3 = nac.RandomCharAug(action=\"swap\", aug_char_max=1)\n",
        "\n",
        "aug4 = nac.RandomCharAug(action=\"delete\", aug_char_max=1)\n",
        "\n",
        "# spelling error\n",
        "\n",
        "aug5 = naw.SpellingAug()\n",
        "\n",
        "# contextual word insertion / substitute\n",
        "\n",
        "aug6 = naw.ContextualWordEmbsAug(\n",
        "\n",
        "model_path='bert-base-uncased',\n",
        "\n",
        "action=\"insert\")\n",
        "\n",
        "aug7 = naw.ContextualWordEmbsAug(\n",
        "\n",
        "model_path='bert-base-uncased',\n",
        "\n",
        "action=\"substitute\")\n",
        "\n",
        "# wordnet-based synonym replacement\n",
        "\n",
        "aug8 = naw.SynonymAug(aug_src='wordnet')\n",
        "\n",
        "# random word deletion\n",
        "\n",
        "aug9 = naw.RandomWordAug()\n",
        "\n",
        "# back-translation\n",
        "\n",
        "aug10 = naw.BackTranslationAug(\n",
        "\n",
        "from_model_name='facebook/wmt19-en-de',\n",
        "\n",
        "to_model_name='facebook/wmt19-de-en', device='cuda')"
      ],
      "metadata": {
        "id": "KCXRyPIBcqWW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def augment_it(text, label):\n",
        "\n",
        "  result= [eval(\"aug\"+str(i)).augment(text)[0] for i in range(1,11) ]\n",
        "\n",
        "  return result, [label]* len(result)"
      ],
      "metadata": {
        "id": "ZKvzQEo5csBR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "\n",
        "imdb_df=pd.DataFrame(imdb_train)\n",
        "\n",
        "texts=[]\n",
        "\n",
        "labels=[]\n",
        "\n",
        "for r in imdb_df.sample(frac=0.1).itertuples(index=False):\n",
        "  t,l=augment_it(r.text, r.label)\n",
        "\n",
        "  texts+= t\n",
        "\n",
        "  labels+=l\n",
        "\n",
        "aug_df=pd.DataFrame()\n",
        "\n",
        "aug_df[\"text\"]= texts\n",
        "\n",
        "aug_df[\"label\"]= labels\n",
        "\n",
        "imdb_augmented=pd.concat([imdb_df, aug_df])\n",
        "\n",
        "imdb_df.shape, imdb_augmented.shape"
      ],
      "metadata": {
        "id": "2scluOVAcxPY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import BertTokenizerFast, BertForSequenceClassification\n",
        "\n",
        "model_path= 'bert-base-uncased'\n",
        "\n",
        "tokenizer = BertTokenizerFast.from_pretrained(model_path)\n",
        "\n",
        "#imdb train data with augmentation\n",
        "\n",
        "imdb_augmented2= Dataset.from_pandas(imdb_augmented)\n",
        "\n",
        "enc_train=imdb_augmented2.map(tokenize_it,  batched=True, batch_size=1000)\n",
        "\n",
        "# imdb train data without augmentation\n",
        "\n",
        "enc_train=imdb_train.map(tokenize_it,  batched=True, batch_size=1000)\n",
        "\n",
        "enc_test=imdb_test.map(tokenize_it,  batched=True, batch_size=1000)\n",
        "\n",
        "enc_val=imdb_val.map(tokenize_it,\n",
        "\n",
        "batched=True, batch_size=1000)\n",
        "\n",
        "model_path= \"bert-base-uncased\"\n",
        "\n",
        "model = BertForSequenceClassification.from_pretrained(model_path,  id2label={0:\"NEG\", 1:\"POS\"},  label2id={\"NEG\":0, \"POS\":1})"
      ],
      "metadata": {
        "id": "F2PhunQ7c5IM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import TrainingArguments, Trainer\n",
        "\n",
        "training_args = TrainingArguments( output_dir='./MyIMDBModel',  do_train=True, do_eval=True, num_train_epochs=3,  per_device_train_batch_size=16,  per_device_eval_batch_size=16, fp16=True, load_best_model_at_end=True)\n",
        "\n",
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=enc_train,\n",
        "    eval_dataset=enc_val,\n",
        "    compute_metrics= compute_metrics)\n",
        "\n",
        "trainer.train()\n",
        "\n",
        "q=[trainer.evaluate(eval_dataset=data) for data in [enc_train, enc_val, enc_test]]\n",
        "\n",
        "pd.DataFrame(q, index=[\"train\",\"val\",\"test\"]).iloc[:,:5]"
      ],
      "metadata": {
        "id": "8uh48DRCdQ_c"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}